/**
 * @file snmp.c
 * @author weizhou
 * @brief SNMP Protocol
 * @version 0.1
 * @date 2024-03-11
 * 
 * @copyright Copyright (c) 2024 HYST Team
 * 
 */



#include "snmp.h"
#include "qelib.h"



#define SNMP_LOG_DOMAN      "snmp"
#define snmp_debug(...)      qelog_debug(SNMP_LOG_DOMAN, __VA_ARGS__)
#define snmp_info(...)       qelog_info(SNMP_LOG_DOMAN, __VA_ARGS__)
#define snmp_notice(...)     qelog_notice(SNMP_LOG_DOMAN, __VA_ARGS__)
#define snmp_warning(...)    qelog_warning(SNMP_LOG_DOMAN, __VA_ARGS__)
#define snmp_error(...)      qelog_error(SNMP_LOG_DOMAN, __VA_ARGS__)
#define snmp_fatal(...)      qelog_fatal(SNMP_LOG_DOMAN, __VA_ARGS__)
#define snmp_hexdump(...)    qehex_debug(SNMP_LOG_DOMAN, __VA_ARGS__)

#define OID_NOT_FOUND        (-1)



/*
 * Report whether a tag byte is outside the set this parser understands.
 *
 * The supported set is the universal/application data types plus the four
 * PDU tags.  Returns qe_true for an unsupported tag, qe_false otherwise.
 */
static qe_bool invalid_dtype(qe_u8 dtype)
{
    /* A table instead of a switch: SNMP_DTYPE_SEQUENCE and
     * SNMP_DTYPE_SEQUENCE_OF may share one tag value, which would be a
     * duplicate case label. */
    static const qe_u8 supported[] = {
        SNMP_DTYPE_SEQUENCE,
        SNMP_DTYPE_SEQUENCE_OF,
        SNMP_DTYPE_INTEGER,
        SNMP_DTYPE_OCTET_STRING,
        SNMP_DTYPE_NULL_ITEM,
        SNMP_DTYPE_OBJ_ID,
        SNMP_DTYPE_COUNTER,
        SNMP_DTYPE_GAUGE,
        SNMP_DTYPE_TIME_TICKS,
        SNMP_DTYPE_OPAQUE,
        SNMP_PDU_GET_REQUEST,
        SNMP_PDU_GET_NEXT_REQUEST,
        SNMP_PDU_GET_RESPONSE,
        SNMP_PDU_SET_REQUEST,
    };
    unsigned k;

    for (k = 0; k < sizeof(supported) / sizeof(supported[0]); k++) {
        if (supported[k] == dtype)
            return qe_false;
    }

    return qe_true;
}

/*
 * Map a tag byte to a short human-readable name for log output.
 *
 * Returns a pointer to a string literal (never NULL); unrecognised tags
 * map to "Unknown".  The result is const because string literals must not
 * be written through.
 *
 * SNMP_DTYPE_SEQUENCE_OF has no case of its own: it presumably shares the
 * SEQUENCE tag value (a duplicate case label would not compile), so it
 * prints as "Sequence" — confirm in snmp.h.
 */
static const char *dtype_str(qe_u8 dtype)
{
    switch (dtype) {
    case SNMP_DTYPE_INTEGER: return "Integer";
    case SNMP_DTYPE_OCTET_STRING: return "String";
    case SNMP_DTYPE_NULL_ITEM: return "Null";
    case SNMP_DTYPE_OBJ_ID: return "Oid";
    case SNMP_DTYPE_SEQUENCE: return "Sequence";
    case SNMP_DTYPE_COUNTER: return "Counter";
    case SNMP_DTYPE_GAUGE: return "Gauge";
    case SNMP_DTYPE_TIME_TICKS: return "TimeTicks";
    case SNMP_DTYPE_OPAQUE: return "Opaque";
    case SNMP_PDU_GET_REQUEST: return "GetRequest";
    case SNMP_PDU_GET_NEXT_REQUEST: return "GetNextRequest";
    case SNMP_PDU_GET_RESPONSE: return "GetResponse";
    case SNMP_PDU_SET_REQUEST: return "SetRequest";
    default: return "Unknown";
    }
}

/*
 * Queue a heap copy of @kv on parser->values.
 *
 * @kv is copied by value, so the caller's object may live on the stack.
 * The copy belongs to the list; snmp_package_response() removes and frees
 * it.  Allocation failure is only caught by qe_assert(); the function
 * always returns qe_ok.
 */
static qe_ret add_kv(snmp_parser *parser, snmp_kv *kv)
{
    snmp_kv *copy;

    copy = qe_malloc(sizeof *copy);
    qe_assert(copy != QE_NULL);

    qe_memcpy(copy, kv, sizeof *copy);
    qe_list_append(&copy->list, &parser->values);

    return qe_ok;
}

/*
 * Size in bytes of the BER length field for a content length of @len.
 *
 * Short form (< 128): 1 byte.  Long form: one 0x8n prefix byte plus n
 * big-endian length bytes, so 2..5 bytes for lengths up to 32 bits.
 * Lengths of 0x1000000 and above need 4 length bytes (5 in total).
 */
static qe_uint get_tlv_llen(qe_uint len)
{
    if (len < 0x80)
        return 1;
    if (len < 0x100)
        return 2;
    if (len < 0x10000)
        return 3;
    if (len < 0x1000000)
        return 4;
    return 5;
}

/*
 * Total encoded size of @tlv: one tag byte, its length field, and
 * tlv.len content bytes.
 */
static qe_uint get_tlv_size(snmp_tlv tlv)
{
    qe_uint header = 1 + get_tlv_llen(tlv.len);

    return header + tlv.len;
}

/*
 * Set the tag, content length and content pointer of @tlv.
 * Only these three fields are touched; @p is stored, not copied.
 */
static void set_tlv(snmp_tlv *tlv, qe_u8 dtype, void *p, int len)
{
    tlv->v     = p;
    tlv->len   = len;
    tlv->dtype = dtype;
}

/*
 * qe_true unless @dtype is one of the PDU tags this agent serves:
 * GetRequest, GetNextRequest or SetRequest.
 */
static qe_bool invalid_request_dtype(qe_u8 dtype)
{
    switch (dtype) {
    case SNMP_PDU_GET_REQUEST:
    case SNMP_PDU_GET_NEXT_REQUEST:
    case SNMP_PDU_SET_REQUEST:
        return qe_false;
    default:
        return qe_true;
    }
}

/*
 * Record an error-status / error-index pair on @parser.
 * snmp_package_response() encodes these into the response PDU.
 */
static void set_error(snmp_parser *parser, int index, int status)
{
    parser->errsts = status;
    parser->erridx = index;
}

/*
 * Load the value of data entry @index into @tlv for a response.
 *
 * If the entry has a get() callback it is invoked first, with the entry's
 * own value storage and &entry->dlen.  On success tlv->v points into the
 * entry (not into any message buffer).
 *
 * Returns qe_ok, or qe_err_param for NULL arguments, an out-of-range
 * index (negative or >= num_entrys) or an unsupported entry dtype.
 */
static qe_ret get_entry_to_tlv(snmp_parser *parser, int index, snmp_tlv *tlv)
{
    snmp_data_entry *entry;

    if (!parser || !tlv || index < 0 || index >= parser->num_entrys) {
        snmp_error("invalid param");
        return qe_err_param;
    }

    entry = &parser->entrys[index];

    tlv->dtype = entry->dtype;

    switch (tlv->dtype) {

    case SNMP_DTYPE_OCTET_STRING:
    case SNMP_DTYPE_OBJ_ID:
        if (entry->get) {
            entry->get(entry->v.string, (int *)&entry->dlen);
            snmp_debug("get string %d", entry->dlen);
            snmp_hexdump(entry->v.string, entry->dlen);
        }
        tlv->len = entry->dlen;
        tlv->v = (qe_u8 *)entry->v.string;
        break;

    case SNMP_DTYPE_INTEGER:
    case SNMP_DTYPE_TIME_TICKS:
    case SNMP_DTYPE_COUNTER:
    case SNMP_DTYPE_GAUGE:
        if (entry->get) {
            entry->get(&entry->v.integer, (int *)&entry->dlen);
        }
        /* NOTE(review): the raw in-memory bytes of entry->v.integer are
         * copied to the wire as-is.  That matches BER's big-endian content
         * only on a big-endian host, and no minimal-length or sign-byte
         * encoding is applied.  Confirm this is intended for the target. */
        tlv->len = sizeof(qe_u32);
        tlv->v = (qe_u8 *)&entry->v.integer;
        break;

    default:
        snmp_error("unknown dtype %d", tlv->dtype);
        return qe_err_param;
    }

    return qe_ok;
}

/*
 * Apply the value carried by @tlv (from a SetRequest) to data entry @index.
 * (The "form" spelling in the name is kept for compatibility.)
 *
 * The TLV tag must equal the entry's dtype.  Otherwise the parser error is
 * set to SNMP_ERROR_BAD_VALUE and qe_err_param is returned.
 *   - Strings and OIDs are passed to set() as (content pointer, length).
 *   - Numeric types are passed as (&tlv->integer, encoded length).
 * Entries without a set() callback are accepted without any action.
 *
 * Returns qe_ok, or qe_err_param for NULL arguments, an out-of-range index
 * (negative or >= num_entrys), a dtype mismatch or an unsupported dtype.
 */
static qe_ret set_entry_form_tlv(snmp_parser *parser, int index, snmp_tlv *tlv)
{
    snmp_data_entry *entry;

    if (!parser || !tlv || index < 0 || index >= parser->num_entrys) {
        snmp_error("invalid param");
        return qe_err_param;
    }

    entry = &parser->entrys[index];

    if (entry->dtype != tlv->dtype) {
        snmp_error("dtype %d not match", tlv->dtype);
        /* NOTE(review): @index is the data-entry index, not the position of
         * the binding inside the request. */
        set_error(parser, index, SNMP_ERROR_BAD_VALUE);
        return qe_err_param;
    }

    switch (entry->dtype) {

    case SNMP_DTYPE_OCTET_STRING:
    case SNMP_DTYPE_OBJ_ID:
        if (entry->set) {
            entry->set(tlv->v, tlv->len);
        }
        break;

    case SNMP_DTYPE_INTEGER:
    case SNMP_DTYPE_TIME_TICKS:
    case SNMP_DTYPE_COUNTER:
    case SNMP_DTYPE_GAUGE:
        /* tlv->integer is the value parse_tlv() decoded for numeric tags;
         * tlv->len is its encoded byte count. */
        if (entry->set) {
            entry->set(&tlv->integer, tlv->len);
        }
        break;

    default:
        snmp_error("unknown dtype %d", entry->dtype);
        return qe_err_param;
    }

    return qe_ok;
}

/*
 * Find the data entry whose OID bytes exactly equal oid[0..len).
 * Entries are scanned in table order and the first match wins.
 * Returns the entry index, or OID_NOT_FOUND when nothing matches.
 */
static int lookup_node(snmp_parser *parser, qe_u8 *oid, int len)
{
    int idx = 0;

    while (idx < parser->num_entrys) {
        snmp_data_entry *e = &parser->entrys[idx];

        if (e->oid_len == len && qe_memcmp(e->oid, oid, len) == 0)
            return idx;
        idx++;
    }

    return OID_NOT_FOUND;
}

/*
 * Decode the BER length field at @buf and store the value in *@len.
 *
 * Short form (bit 7 clear): the byte itself is the length.
 * Long form: the low 7 bits give how many big-endian length bytes follow.
 *
 * Returns the number of bytes the length field occupies (always >= 1 on
 * success).  Returns 0, with *@len set to 0, for encodings this parser
 * does not support: the indefinite form 0x80, which has no length bytes to
 * read, or more length bytes than fit in a qe_uint.
 *
 * The caller must ensure every length byte is readable.
 */
qe_uint snmp_parse_length(qe_u8 *buf, qe_uint *len)
{
    qe_uint nbytes;
    qe_uint value = 0;
    qe_uint i;

    if (!(buf[0] & 0x80)) {
        *len = buf[0];
        return 1;
    }

    nbytes = buf[0] & 0x7F;
    if (nbytes == 0 || nbytes > sizeof(qe_uint)) {
        *len = 0;
        return 0;
    }

    for (i = 1; i <= nbytes; i++)
        value = (value << 8) | buf[i];

    *len = value;
    return nbytes + 1;
}

/*
 * Decode the TLV at msg->buf[msg->index] into @tlv and advance msg->index.
 *
 * On success:
 *   tlv->s        start of the TLV (its tag byte)
 *   tlv->v        start of the content
 *   tlv->len      content length
 *   tlv->dtype    tag byte
 *   tlv->n        where parsing continues.  For constructed types
 *                 (SEQUENCE, SEQUENCE OF, PDUs) this is the content itself,
 *                 so children are parsed next.  For primitive types it is
 *                 just past the content.
 *   tlv->integer  big-endian value of INTEGER, COUNTER, GAUGE and
 *                 TIME_TICKS content (low 32 bits if longer than 4 bytes)
 *
 * Every byte read is bounds-checked against msg->size, the buffer capacity
 * recorded by snmp_message_new().
 *
 * Returns qe_ok, or qe_err_param for NULL arguments, an unsupported tag, an
 * unsupported length encoding, or a TLV that runs past msg->size.
 */
qe_ret parse_tlv(snmp_message *msg, snmp_tlv *tlv)
{
    qe_uint pos;
    qe_uint avail;
    qe_uint llen;
    qe_uint hdr;
    qe_u8 dtype;
    qe_u8 first;

    if (!msg || !tlv)
        return qe_err_param;

    pos = (qe_uint)msg->index;

    /* Need at least the tag byte and the first length byte. */
    if (pos >= msg->size || msg->size - pos < 2) {
        snmp_error("tlv at %d exceeds message size %d", (int)pos,
            (int)msg->size);
        return qe_err_param;
    }
    avail = msg->size - pos;

    dtype = msg->buf[pos];

    if (invalid_dtype(dtype)) {
        snmp_error("invalid dtype %d", dtype);
        return qe_err_param;
    }

    /* Validate the long-form byte count before snmp_parse_length() reads
     * those bytes.  0x80 (indefinite form) is rejected outright. */
    first = msg->buf[pos + 1];
    if (first & 0x80) {
        qe_uint nbytes = first & 0x7F;

        if (nbytes == 0 || nbytes > avail - 2) {
            snmp_error("bad tlv length field 0x%x", first);
            return qe_err_param;
        }
    }

    llen = snmp_parse_length(&msg->buf[pos + 1], &tlv->len);
    if (llen == 0) {
        snmp_error("unsupported length encoding 0x%x", first);
        return qe_err_param;
    }

    hdr = 1 + llen;
    if (tlv->len > avail - hdr) {
        snmp_error("tlv len %d exceeds message", (int)tlv->len);
        return qe_err_param;
    }

    tlv->s = &msg->buf[pos];
    tlv->v = tlv->s + hdr;
    tlv->dtype = dtype;

    /* Comparisons instead of case labels: SNMP_DTYPE_SEQUENCE and
     * SNMP_DTYPE_SEQUENCE_OF may share one tag value. */
    if (dtype == SNMP_DTYPE_SEQUENCE ||
        dtype == SNMP_DTYPE_SEQUENCE_OF ||
        dtype == SNMP_PDU_GET_REQUEST ||
        dtype == SNMP_PDU_GET_NEXT_REQUEST ||
        dtype == SNMP_PDU_GET_RESPONSE ||
        dtype == SNMP_PDU_SET_REQUEST) {
        tlv->n = tlv->v;
    } else {
        tlv->n = tlv->v + tlv->len;
    }

    switch (dtype) {
    case SNMP_DTYPE_INTEGER:
    case SNMP_DTYPE_COUNTER:
    case SNMP_DTYPE_GAUGE:
    case SNMP_DTYPE_TIME_TICKS:
        {
            /* Accumulate in an unsigned 32-bit value.  Shifting a promoted
             * qe_u8 (an int) left by 24 is undefined when bit 7 is set. */
            qe_u32 acc = 0;
            qe_uint i;

            for (i = 0; i < tlv->len; i++)
                acc = (acc << 8) | tlv->v[i];
            tlv->integer = acc;
        }
        break;
    default:
        break;
    }

    msg->index += (tlv->n - tlv->s);

    return qe_ok;
}

/*
 * Parse one variable binding (SEQUENCE { OID, value }) and queue the
 * response binding on parser->values.
 *
 * How the response value is produced depends on the request type:
 *   - GetRequest: the value comes from the matched data entry.
 *   - GetNextRequest: the value and OID come from the entry after the
 *     matched one.
 *   - SetRequest: the request value is applied to the entry and echoed back.
 *
 * An OID with no matching entry is not fatal.  Its value is consumed, the
 * error is recorded as SNMP_ERROR_NO_SUCH_NAME, and qe_ok is returned so
 * the remaining bindings still parse.  Structural errors (wrong tags,
 * undecodable TLVs) and a GetNext past the last entry return an error code.
 *
 * On success parser->sequence.len grows by the encoded size of the
 * response binding.
 */
static qe_ret parse_kv(snmp_parser *parser, snmp_message *msg)
{
    int index;
    qe_ret ret;
    snmp_kv kv;
    snmp_data_entry *entry;

    ret = parse_tlv(msg, &kv.seq);
    if (ret != qe_ok) {
        snmp_error("parse kv seq tlv error:%d", ret);
        return ret;
    }
    snmp_debug("sequence: %s len %d", dtype_str(kv.seq.dtype), kv.seq.len);

    if (kv.seq.dtype != SNMP_DTYPE_SEQUENCE) {
        snmp_error("kv sequence dtype %d error", kv.seq.dtype);
        return qe_err_param;
    }

    ret = parse_tlv(msg, &kv.key);
    if (ret != qe_ok) {
        snmp_error("parse kv key tlv error:%d", ret);
        return ret;
    }
    snmp_debug("key: %s len %d", dtype_str(kv.key.dtype), kv.key.len);

    if (kv.key.dtype != SNMP_DTYPE_OBJ_ID) {
        snmp_error("kv key dtype %d error", kv.key.dtype);
        return qe_err_param;
    }

    snmp_debug("Oid:");
    snmp_hexdump(kv.key.v, kv.key.len);

    index = lookup_node(parser, kv.key.v, kv.key.len);
    if (index < 0) {
        snmp_error("lookup oid error");
        /* NOTE(review): SNMP's error-index is the 1-based position of the
         * failing binding; this records OID_NOT_FOUND instead. */
        set_error(parser, index, SNMP_ERROR_NO_SUCH_NAME);
        /* Consume the value so the next binding starts at its SEQUENCE. */
        ret = parse_tlv(msg, &kv.val);
        if (ret != qe_ok) {
            snmp_error("parse kv val tlv error:%d", ret);
        }
        return ret;
    }
    snmp_debug("lookup node:%d", index);

    if (parser->request_type == SNMP_PDU_GET_NEXT_REQUEST) {
        /* GetNext answers with the entry that follows the matched OID. */
        index += 1;
        if (index >= parser->num_entrys) {
            snmp_error("request next index %d out of range %d", index, 
                parser->num_entrys - 1);
            set_error(parser, index, SNMP_ERROR_NO_SUCH_NAME);
            return qe_err_range;
        }
        entry = &parser->entrys[index];
        set_tlv(&kv.key, kv.key.dtype, entry->oid, entry->oid_len);
    }

    ret = parse_tlv(msg, &kv.val);
    if (ret != qe_ok) {
        snmp_error("parse kv val tlv error:%d", ret);
        return ret;
    }
    snmp_debug("val: %s len %d", dtype_str(kv.val.dtype), kv.val.len);

    if (parser->request_type == SNMP_PDU_GET_REQUEST || 
        parser->request_type == SNMP_PDU_GET_NEXT_REQUEST) {
        snmp_debug("get %d to tlv", index);
        get_entry_to_tlv(parser, index, &kv.val);
    } else if (parser->request_type == SNMP_PDU_SET_REQUEST) {
        snmp_debug("set %d to tlv", index);
        snmp_debug("value:");
        snmp_hexdump(kv.val.v, kv.val.len);
        set_entry_form_tlv(parser, index, &kv.val);
    }

    /* Response binding length, as it will be encoded. */
    kv.seq.len = get_tlv_size(kv.key) + get_tlv_size(kv.val);
    snmp_debug("response sequence len %d", kv.seq.len);

    add_kv(parser, &kv);

    parser->sequence.len += get_tlv_size(kv.seq);
    snmp_debug("response sequence of len %d", parser->sequence.len);

    return qe_ok;
}

/*
 * Parse the variable-binding list (SEQUENCE OF) and every binding in it.
 *
 * On entry msg->index must point at the list's tag.  parser->sequence
 * receives the list TLV.  Its len is then reset to 0 and rebuilt by
 * parse_kv() as the encoded length of the response binding list.
 *
 * Returns qe_ok, the first parse_tlv()/parse_kv() error, or qe_err_param
 * when the tag is not SEQUENCE OF.
 */
static qe_ret parse_sequence_of(snmp_parser *parser, snmp_message *msg)
{
    qe_ret ret;
    qe_uint start;
    qe_uint end;

    start = (qe_uint)msg->index;

    /* Parse sequence of */
    ret = parse_tlv(msg, &parser->sequence);
    if (ret != qe_ok) {
        snmp_error("parse sequence tlv error:%d", ret);
        return ret;
    }
    snmp_debug("sequence of: %s len %d", 
        dtype_str(parser->sequence.dtype),
        parser->sequence.len);

    if (parser->sequence.dtype != SNMP_DTYPE_SEQUENCE_OF) {
        snmp_error("sequence dtype %d error", parser->sequence.dtype);
        return qe_err_param;
    }

    /* Bound the loop by the list's own extent, as an absolute offset into
     * msg->buf.  parser->top.len is a content length, not an offset, so it
     * is not a usable bound. */
    end = start + (qe_uint)(parser->sequence.v - parser->sequence.s) +
          parser->sequence.len;

    parser->sequence.len = 0;

    while (msg->index < end) {
        ret = parse_kv(parser, msg);
        if (ret != qe_ok) {
            snmp_error("parse kv error:%d", ret);
            return ret;
        }
    }
    snmp_debug("parse kv finish");
    snmp_debug("response sequence of len:%d", parser->sequence.len);

    return ret;
}

/*
 * Parse the PDU, starting at msg->index, in this order: PDU tag,
 * request-id, error-status, error-index, then the variable-binding list.
 *
 * The PDU tag is stored in parser->request_type before it is validated.
 * It must be GetRequest, GetNextRequest or SetRequest, otherwise
 * qe_err_notsupport is returned.
 *
 * After the binding list is processed, parser->request.len is overwritten
 * with the content length of the response PDU.  This happens even when
 * parse_sequence_of() failed; its error code is still returned.
 *
 * NOTE(review): request-id, error-status and error-index are not checked
 * to be INTEGER TLVs.
 */
static qe_ret parse_request(snmp_parser *parser, snmp_message *msg)
{
    qe_ret ret;

    /* Parse request */
    ret = parse_tlv(msg, &parser->request);
    if (ret != qe_ok) {
        snmp_error("parse request tlv error:%d", ret);
        return ret;
    }
    snmp_debug("request: %s len %d", dtype_str(parser->request.dtype),
        parser->request.len);
    parser->request_type = parser->request.dtype;

    if (invalid_request_dtype(parser->request.dtype)) {
        snmp_error("invalid request dtype:%d", parser->request.dtype);
        return qe_err_notsupport;
    }

    /* Parse request id */
    ret = parse_tlv(msg, &parser->request_id);
    if (ret != qe_ok) {
        snmp_error("parse request id tlv error:%d", ret);
        return ret;
    }
    snmp_debug("request id: %s %d", dtype_str(parser->request_id.dtype),
        parser->request_id.integer);

    /* Parse error status */
    ret = parse_tlv(msg, &parser->error_status);
    if (ret != qe_ok) {
        snmp_error("parse error status tlv error:%d", ret);
        return ret;
    }
    snmp_debug("error status: %s %d %d", 
        dtype_str(parser->error_status.dtype),
        parser->error_status.len,
        parser->error_status.integer);

    /* Parse error index */
    ret = parse_tlv(msg, &parser->error_index);
    if (ret != qe_ok) {
        snmp_error("parse error status index error:%d", ret);
        return ret;
    }
    snmp_debug("error index: %s %d",
        dtype_str(parser->error_index.dtype),
        parser->error_index.integer);

    /* Parse sequence of */
    ret = parse_sequence_of(parser, msg);
    if (ret != qe_ok) {
        snmp_error("parse sequence of error:%d", ret);
    }

    /* Reuse the request TLV as the response PDU header: its length becomes
     * the sum of the children's encoded sizes, as snmp_package_response()
     * will write them (error fields and binding list included). */
    parser->request.len = get_tlv_size(parser->request_id) + 
                          get_tlv_size(parser->error_status) + 
                          get_tlv_size(parser->error_index) + 
                          get_tlv_size(parser->sequence);
    snmp_debug("response request len:%d", parser->request.len);

    return ret;
}

/*
 * Parse a complete SNMPv1 request held in @msg into @parser.
 *
 * Validates, in order:
 *   - the outer SEQUENCE;
 *   - version (a one-byte INTEGER equal to SNMP_V1);
 *   - community (an OCTET STRING of SNMP_COMMUNITY_SIZE bytes equal to
 *     SNMP_COMMUNITY).
 * It then parses the PDU and its bindings via parse_request(), which also
 * invokes the data entries' get/set callbacks.
 *
 * The parser's error-status/index are reset to 0 first.  After
 * parse_request(), parser->top.len is rewritten to the response message's
 * content length, even when parse_request() failed.
 *
 * Returns qe_ok, qe_err_param for NULL arguments, qe_err_notsupport for an
 * unsupported message, or the error from a parsing step.
 */
qe_ret snmp_message_parse(snmp_parser *parser, snmp_message *msg)
{
    qe_ret ret;

    if (!parser || !msg)
        return qe_err_param;

    msg->index = 0;
    set_error(parser, 0, 0);

    /* Parse top tlv */
    ret = parse_tlv(msg, &parser->top);
    if (ret != qe_ok) {
        snmp_error("parse top tlv err:%d", ret);
        return ret;
    }
    snmp_debug("top: %s len %d", dtype_str(parser->top.dtype), parser->top.len);

    if (parser->top.dtype != SNMP_DTYPE_SEQUENCE) {
        snmp_error("top tlv type %d err", parser->top.dtype);
        return qe_err_notsupport;
    }

    /* Parse version */
    ret = parse_tlv(msg, &parser->version);
    if (ret != qe_ok) {
        snmp_error("parse version err:%d", ret);
        return ret;
    }
    snmp_debug("version: %s %d", dtype_str(parser->version.dtype),
        parser->version.integer);

    /* The length check keeps the v[0] read inside the version TLV. */
    if (!(parser->version.dtype == SNMP_DTYPE_INTEGER &&
          parser->version.len == 1 &&
          parser->version.v[0] == SNMP_V1)) {
        snmp_error("version error %d",
            parser->version.len ? (int)parser->version.v[0] : -1);
        return qe_err_notsupport;
    }

    /* Parse community */
    ret = parse_tlv(msg, &parser->community);
    if (ret != qe_ok) {
        snmp_error("parse community err:%d", ret);
        return ret;
    }
    /* The community is not NUL-terminated: always print it with an
     * explicit length. */
    snmp_debug("community: %s %.*s", dtype_str(parser->community.dtype),
        (int)parser->community.len, (const char *)parser->community.v);

    if (!(parser->community.dtype == SNMP_DTYPE_OCTET_STRING &&
          parser->community.len == SNMP_COMMUNITY_SIZE)) {
        snmp_error("community %d %d error", parser->community.dtype, 
            parser->community.len);
        return qe_err_notsupport;
    }

    if (qe_memcmp(parser->community.v, SNMP_COMMUNITY, 
        SNMP_COMMUNITY_SIZE) != 0) {
        snmp_error("unknown community %.*s", (int)parser->community.len,
            (const char *)parser->community.v);
        return qe_err_notsupport;
    }

    ret = parse_request(parser, msg);
    if (ret != qe_ok) {
        snmp_error("request process error:%d", ret);
    }

    /* Response message content length, as it will be encoded. */
    parser->top.len = get_tlv_size(parser->version) + 
                      get_tlv_size(parser->community) + 
                      get_tlv_size(parser->request);
    snmp_debug("response top len:%d", parser->top.len);

    return ret;
}

/*
 * Encode @tlv at msg->buf[@pos]: tag byte, BER length, then content.
 *
 * Constructed types (SEQUENCE, SEQUENCE OF and the PDU tags) get only their
 * tag and length here; their children are encoded by later calls.  All
 * other types also copy tlv->len content bytes from tlv->v.
 *
 * Lengths use the short form below 128 and otherwise the long form with
 * 1..4 big-endian length bytes.  When @update_index is set, msg->index is
 * moved past the bytes written.
 *
 * Returns qe_ok, qe_err_param for NULL arguments, or qe_err_range when the
 * encoding would not fit within msg->size bytes.  In that case nothing is
 * written.
 */
static qe_ret package_tlv_internal(snmp_message *msg, int pos, snmp_tlv *tlv,
    qe_bool update_index)
{
    int llen;
    int shift = 0;
    int has_value;
    qe_uint hdr;

    if (!msg || !tlv)
        return qe_err_param;

    has_value = (tlv->dtype != SNMP_DTYPE_SEQUENCE) &&
                (tlv->dtype != SNMP_DTYPE_SEQUENCE_OF) &&
                (tlv->dtype != SNMP_PDU_GET_NEXT_REQUEST) &&
                (tlv->dtype != SNMP_PDU_GET_REQUEST) &&
                (tlv->dtype != SNMP_PDU_GET_RESPONSE) &&
                (tlv->dtype != SNMP_PDU_SET_REQUEST);

    /* Number of long-form length bytes; 0 selects the short form. */
    if (tlv->len < 0x80)
        llen = 0;
    else if (tlv->len < 0x100)
        llen = 1;
    else if (tlv->len < 0x10000)
        llen = 2;
    else if (tlv->len < 0x1000000)
        llen = 3;
    else
        llen = 4;

    /* Tag byte + first length byte + long-form length bytes. */
    hdr = 2 + (qe_uint)llen;
    if (pos < 0 || (qe_uint)pos > msg->size ||
        hdr > msg->size - (qe_uint)pos ||
        (has_value && tlv->len > msg->size - (qe_uint)pos - hdr)) {
        snmp_error("tlv at %d does not fit in %d bytes", pos,
            (int)msg->size);
        return qe_err_range;
    }

    /* package dtype */
    msg->buf[pos++] = tlv->dtype;

    /* package length */
    if (llen == 0) {
        msg->buf[pos++] = tlv->len;
    } else {
        snmp_debug("buf[%d] 0x%x llen:%d", pos, 0x80 | llen, llen);
        msg->buf[pos++] = 0x80 | llen;
        /* Big-endian: the least significant byte is written last. */
        while (llen--) {
            msg->buf[pos+llen] = (tlv->len >> shift) & 0xFF;
            snmp_debug("buf[%d] 0x%x", pos+llen, msg->buf[pos+llen]);
            shift += 8;
        }
        pos += shift / 8;
    }

    /* package value */
    if (has_value) {
        qe_memcpy(&msg->buf[pos], tlv->v, tlv->len);
        pos += tlv->len;
    }

    if (update_index)
        msg->index = pos;

    return qe_ok;
}

/*
 * Encode @tlv at the current write position (msg->index) and advance
 * msg->index past it.  Returns package_tlv_internal()'s result.
 */
static qe_ret package_tlv(snmp_message *msg, snmp_tlv *tlv)
{
    qe_ret ret;

    ret = package_tlv_internal(msg, msg->index, tlv, qe_true);
    return ret;
}

/*
 * Encode @tlv at absolute offset @pos without moving msg->index
 * (presumably for patching a header after the fact).
 * NOTE(review): not called anywhere in this file.
 */
static qe_ret package_tlv_pos(snmp_message *msg, int pos, snmp_tlv *tlv)
{
    qe_bool advance = qe_false;

    return package_tlv_internal(msg, pos, tlv, advance);
}

/*
 * Append one response binding at msg->index: its SEQUENCE header, then the
 * OID key, then the value.  Errors from package_tlv() are not reported.
 */
static void package_kv(snmp_message *msg, snmp_kv *kv)
{
    snmp_tlv *seq = &kv->seq;
    snmp_tlv *key = &kv->key;
    snmp_tlv *val = &kv->val;

    snmp_debug("sequence index:%d", msg->index);
    (void)package_tlv(msg, seq);

    snmp_debug("key index:%d", msg->index);
    (void)package_tlv(msg, key);

    snmp_debug("val index:%d", msg->index);
    (void)package_tlv(msg, val);
}

/*
 * Serialise the GetResponse for the request that snmp_message_parse()
 * decoded into @parser.
 *
 * Output is written from offset 0 of msg->buf.  It reuses the header TLVs
 * and the response lengths computed during parsing, switches the PDU tag to
 * GetResponse, and takes error-status / error-index from parser->errsts /
 * parser->erridx.  Every queued binding is encoded, then removed from
 * parser->values and freed.
 *
 * NOTE(review): the version, community, request-id and key TLVs still point
 * into the buffer the request was parsed from.  @msg therefore appears to
 * need to be a different buffer, or those bytes may be overwritten before
 * they are copied.
 *
 * On return msg->index is the number of bytes written.  Returns qe_ok,
 * qe_err_param for NULL arguments, or qe_err_range when the bytes written
 * do not match the precomputed message size (some TLV did not fit, or the
 * lengths are inconsistent); such a message must not be sent.
 */
qe_ret snmp_package_response(snmp_parser *parser, snmp_message *msg)
{
    qe_list *node, *tmp;
    snmp_kv *kv;
    qe_uint expected;

    if (!parser || !msg)
        return qe_err_param;

    msg->index = 0;

    /* NOTE(review): each encodes only the first byte of errsts / erridx.
     * That is the value only if the field is one byte wide, or the host is
     * little-endian and the value fits in 7 bits. */
    set_tlv(&parser->error_status, SNMP_DTYPE_INTEGER, 
        &parser->errsts, 1);
    set_tlv(&parser->error_index, SNMP_DTYPE_INTEGER, 
        &parser->erridx, 1);

    /* package top */
    snmp_debug("top index:%d %d", msg->index, parser->top.len);
    package_tlv(msg, &parser->top);

    /* package version */
    snmp_debug("version index:%d", msg->index);
    package_tlv(msg, &parser->version);

    /* package community */
    snmp_debug("community index:%d", msg->index);
    package_tlv(msg, &parser->community);

    /* package request */
    parser->request.dtype = SNMP_PDU_GET_RESPONSE;
    snmp_debug("request index:%d", msg->index);
    package_tlv(msg, &parser->request);

    /* package request id */
    snmp_debug("request id index:%d", msg->index);
    package_tlv(msg, &parser->request_id);

    /* package error status */
    snmp_debug("error status index:%d %d", msg->index, parser->error_status.len);
    package_tlv(msg, &parser->error_status);

    /* package error index */
    snmp_debug("error index index:%d", msg->index);
    package_tlv(msg, &parser->error_index);

    /* package sequence of */
    snmp_debug("sequence of index:%d", msg->index);
    package_tlv(msg, &parser->sequence);

    /* Encode and release every queued binding. */
    qe_list_foreach_safe(node, tmp, &parser->values) {
        kv = qe_list_entry(node, snmp_kv, list);
        package_kv(msg, kv);
        qe_list_remove(&kv->list);
        qe_free(kv);
    }
    snmp_debug("after kv, index:%d", msg->index);

    snmp_hexdump(msg->buf, msg->index);

    /* All header lengths were precomputed during parsing.  If the bytes
     * actually written differ, a TLV was skipped (e.g. it did not fit) or
     * the lengths are inconsistent, and the message would be malformed. */
    expected = get_tlv_size(parser->top);
    if ((qe_uint)msg->index != expected) {
        snmp_error("response size %d != expected %d", (int)msg->index,
            (int)expected);
        return qe_err_range;
    }

    return qe_ok;
}

/*
 * Allocate a message with @size bytes of buffer storage placed directly
 * after the snmp_message header (msg->buf is presumably a trailing array —
 * confirm in snmp.h).  msg->index starts at 0 and msg->size records the
 * capacity.
 *
 * Returns QE_NULL if the allocation fails or if adding @size to the header
 * size would overflow.  The message is one qe_malloc() block.
 */
snmp_message *snmp_message_new(qe_uint size)
{
    snmp_message *msg;

    /* Reject sizes whose addition to the header size would wrap around. */
    if (sizeof(snmp_message) + size < size)
        return QE_NULL;

    msg = qe_malloc(sizeof(snmp_message) + size);
    if (!msg)
        return QE_NULL;

    msg->index = 0;
    msg->size  = size;
    return msg;
}
